import torch
import clip
from PIL import Image
from torch.optim import AdamW
from torchvision.transforms import transforms
import os
import json

# Load an image
def load_image(image_path, device, preprocess):
    """Load an image file and return it as a single-image batch on ``device``.

    Args:
        image_path: Path of the image file to open with PIL.
        device: Target passed to ``.to(...)`` on the preprocessed result.
        preprocess: Callable applied to the RGB PIL image. Its result must
            support ``.unsqueeze(0)`` (presumably the transform returned by
            ``clip.load`` -- confirm against the caller).

    Returns:
        The preprocessed image with a new leading dimension, moved to
        ``device``.
    """
    # Close the file handle deterministically instead of waiting for garbage
    # collection; ``convert`` produces an independent, fully loaded image, so
    # it stays usable after the source file is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return preprocess(rgb_image).unsqueeze(0).to(device)

# Convert a tensor to a PIL image
def tensor_to_pil(image_tensor):
    """Convert an image tensor into a PIL image.

    The tensor is squeezed on dimension 0, moved to the CPU, multiplied by
    255, clamped to ``[0, 255]``, cast to ``uint8`` and reordered with
    ``permute(1, 2, 0)`` before being handed to ``Image.fromarray``.

    Args:
        image_tensor: Image tensor; assumed to be laid out as
            ``(1, C, H, W)`` or ``(C, H, W)`` with values in ``[0, 1]``
            -- TODO confirm with callers.

    Returns:
        The ``PIL.Image.Image`` built from the converted array.
    """
    channels_first = image_tensor.squeeze(0).cpu()
    scaled = torch.clamp(channels_first * 255, 0, 255)
    # Cast first, then move the first axis to the end, as Image.fromarray
    # is given the permuted array.
    channels_last = scaled.byte().permute(1, 2, 0)
    return Image.fromarray(channels_last.numpy())

# Undo the normalization
def undo_preprocess(tensor):
    """Undo per-channel normalization and return the result as a PIL image.

    Computes ``tensor * std + mean`` per channel, clamps to ``[0, 1]``,
    scales to ``[0, 255]``, rounds to the nearest integer, and converts the
    first (and only) batch entry with ``transforms.ToPILImage``.

    Args:
        tensor: Normalized image batch indexed as ``(1, 3, H, W)`` (the
            statistics are broadcast over dimension 1 and dimension 0 is
            squeezed away).

    Returns:
        The de-normalized image as a ``PIL.Image.Image``.
    """
    # NOTE(review): these per-channel statistics appear to be the ones used by
    # OpenAI CLIP's preprocessing -- confirm they match the ``preprocess``
    # transform that produced ``tensor``.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=tensor.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=tensor.device)

    # Only the values are needed here, so drop any autograd history.
    pixels = tensor.detach() * std[None, :, None, None] + mean[None, :, None, None]

    # Round before the integer cast: a bare cast truncates, so a value such as
    # 127.9999 (float error from normalize/de-normalize) would become 127
    # instead of 128 and systematically darken the image.
    pixels = (pixels.clamp(0, 1) * 255.0).round().to(torch.uint8)

    return transforms.ToPILImage()(pixels.squeeze(0).cpu())

# Load a JSON file
def load_json(json_path):
    """Read and parse a JSON file.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly so non-ASCII content (e.g. image paths) is
    # read the same way regardless of the platform's locale encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Generate a poisoned image
def generate_poison_image(target_image_path, base_image_path, clip_model, device,
                           max_iter, learning_rate, beta, save_path, preprocess=None):
    """Optimize a copy of the base image so its features match the target's.

    Starting from the preprocessed base image, each of the ``max_iter``
    iterations takes one AdamW step on the MSE between
    ``clip_model.encode_image`` of the working image and of the target image,
    then blends the working image back toward the base image with weight
    ``beta * learning_rate``. The result is de-normalized with
    ``undo_preprocess`` and saved to ``save_path``.

    Args:
        target_image_path: Path of the image whose features are matched.
        base_image_path: Path of the image the result is initialized from and
            pulled back toward.
        clip_model: Model exposing ``encode_image``.
        device: Device the image tensors are moved to.
        max_iter: Number of optimization iterations.
        learning_rate: AdamW learning rate; also scales the pull-back step.
        beta: Strength of the pull-back toward the base image.
        save_path: File path the resulting image is written to.
        preprocess: Optional image transform passed to ``load_image``. When
            omitted, the module-level ``preprocess`` is used, as earlier
            versions of this function did.

    Returns:
        ``save_path``.

    Raises:
        NameError: If ``preprocess`` is not given and no module-level
            ``preprocess`` exists.
    """
    if preprocess is None:
        # Backward compatibility: this function used to read the global
        # ``preprocess`` created by the script entry point.
        preprocess = globals().get("preprocess")
        if preprocess is None:
            raise NameError(
                "name 'preprocess' is not defined; pass preprocess=... or "
                "define a module-level 'preprocess'"
            )

    target_image = load_image(target_image_path, device, preprocess)
    base_image = load_image(base_image_path, device, preprocess)

    poison_image = base_image.clone().detach().requires_grad_(True)

    # NOTE(review): AdamW applies its default weight decay to the image tensor
    # too -- confirm that is intended.
    optimizer = AdamW([poison_image], lr=learning_rate)
    mse_loss = torch.nn.MSELoss()

    # The target never changes, so encode it once and without autograd.
    # Assumes ``encode_image`` is deterministic (model in eval mode).
    with torch.no_grad():
        target_features = clip_model.encode_image(target_image)

    for iteration in range(max_iter):
        optimizer.zero_grad()

        poison_features = clip_model.encode_image(poison_image)
        feature_loss = mse_loss(poison_features, target_features)

        # Differentiate with respect to the image only; a plain ``backward()``
        # would also accumulate gradients into every model parameter.
        poison_image.grad = torch.autograd.grad(feature_loss, poison_image)[0]
        optimizer.step()

        # Pull the image back toward the base image, updating in place.
        with torch.no_grad():
            poison_image.copy_(
                (poison_image + beta * base_image * learning_rate)
                / (1 + beta * learning_rate)
            )

        if iteration % 20 == 0:
            print(f"Iteration {iteration}/{max_iter}, Feature Loss: {feature_loss.item()}")

    poison_pil = undo_preprocess(poison_image.detach())
    poison_pil.save(save_path)
    print(f"Poison image saved to {save_path}")

    return save_path  # Path of the generated image

# Generate poisoned images in batch
def generate_poison_images_from_json(target_json_path, base_json_path, clip_model, device,
                                     max_iter, learning_rate, beta, save_folder, output_json_path):
    """Generate one image per (target, base) pair listed in two JSON files.

    The two JSON objects are walked in lockstep, in file order. For the pair at
    position ``idx``, the target entry's ``"chosen"`` path and the base entry's
    ``"reject"`` path are passed to ``generate_poison_image``, and the result is
    saved as ``poisoned_pair_{idx}.png`` inside ``save_folder``. A summary that
    maps ``poisoned_pair_{idx}`` to its target, base and generated paths is
    written to ``output_json_path``.

    Args:
        target_json_path: JSON file whose entries provide ``"chosen"`` paths.
        base_json_path: JSON file whose entries provide ``"reject"`` paths.
        clip_model: Forwarded to ``generate_poison_image``.
        device: Forwarded to ``generate_poison_image``.
        max_iter: Forwarded to ``generate_poison_image``.
        learning_rate: Forwarded to ``generate_poison_image``.
        beta: Forwarded to ``generate_poison_image``.
        save_folder: Directory for the generated images; created if missing.
        output_json_path: Path of the summary JSON file to write.

    Raises:
        AssertionError: If the two JSON files contain a different number of
            entries.
    """
    target_pairs = load_json(target_json_path)
    base_pairs = load_json(base_json_path)

    # Raised explicitly (rather than with ``assert``) so the check still runs
    # under ``python -O``; the exception type is kept for existing callers.
    if len(target_pairs) != len(base_pairs):
        raise AssertionError("The number of pairs in both JSON files should be the same.")

    # Make sure the output directory exists even when the caller did not
    # create it. An empty folder name means the current directory.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)

    # Paths recorded for every generated image.
    generated_paths = {}

    for idx, (target_key, base_key) in enumerate(zip(target_pairs, base_pairs)):
        target_image_path = target_pairs[target_key]["chosen"]
        base_image_path = base_pairs[base_key]["reject"]

        save_path = os.path.join(save_folder, f"poisoned_pair_{idx}.png")

        print(f"Generating poison image for pair {idx}...")
        poisoned_image_path = generate_poison_image(
            target_image_path, base_image_path, clip_model, device,
            max_iter, learning_rate, beta, save_path
        )

        generated_paths[f"poisoned_pair_{idx}"] = {
            "target": target_image_path,
            "base": base_image_path,
            "generated_image": poisoned_image_path
        }

    # Write the collected paths to the summary JSON file.
    with open(output_json_path, 'w', encoding='utf-8') as json_file:
        json.dump(generated_paths, json_file, indent=4)
    print(f"New JSON file saved at {output_json_path}")

# Main program
if __name__ == "__main__":
    # Prefer GPU 2, but fall back to GPU 0 when the host exposes fewer than
    # three CUDA devices (requesting "cuda:2" there would fail), and to the
    # CPU when CUDA is unavailable.
    preferred_gpu = 2
    if torch.cuda.is_available():
        gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
        device = f"cuda:{gpu_index}"
    else:
        device = "cpu"

    # ``preprocess`` is kept as a module-level name on purpose:
    # generate_poison_image reads it from the module globals.
    clip_model, preprocess = clip.load("ViT-L/14", device=device)

    target_json_path = "poison_data/attractive_1.0%_sd35.json"
    base_json_path = "poison_data/attractive_1.0%_clean.json"

    save_folder = "poison_data/collision_attractive_1.0%_sd35"
    os.makedirs(save_folder, exist_ok=True)

    output_json_path = "poison_data/collision_attractive_1.0%_sd35.json"  # Where the new JSON file is saved

    max_iter = 100
    learning_rate = 0.008
    beta = 0.005

    generate_poison_images_from_json(
        target_json_path, base_json_path, clip_model, device,
        max_iter, learning_rate, beta, save_folder, output_json_path
    )
